package com.nx.platform.es.bean.modle.rescore;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.nx.platform.es.bean.modle.score.ScoreFunctionField;
import com.nx.platform.es.common.utils.MoreMaps;
import org.apache.commons.collections4.MapUtils;
import org.apache.lucene.search.Query;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.rescore.RescorerBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * {@link RescoreFieldHandler} that wraps an {@code sltr} query in a {@link QueryRescorerBuilder}.
 *
 * <p>Keys read from the field configuration:
 * <ul>
 *   <li>{@code model} - required, non-empty model name written into the {@code sltr} query.</li>
 *   <li>{@code paramNames} - set of request parameter names that must all be present; if the set
 *       is not configured, or any listed parameter is missing, no rescorer is produced.</li>
 *   <li>{@code query_weight} - optional weight of the original query score, default {@code 1}.</li>
 *   <li>{@code rescore_query_weight} - optional weight of the {@code sltr} score, default {@code 1}.</li>
 * </ul>
 *
 * @since 2016-10-16
 */
public class SltrHandler implements RescoreFieldHandler {

    /** Default for both rescorer weights; matches the Elasticsearch rescore defaults of 1. */
    private static final float DEFAULT_WEIGHT = 1f;

    /**
     * Builds a rescorer that re-scores hits with the configured {@code sltr} model.
     *
     * @param fieldConfig handler configuration; see the class documentation for the keys read
     * @param functions   not used by this handler
     * @param params      request parameters; the whole map is passed as the {@code sltr} params
     * @return the configured rescorer, or {@code null} when {@code params} is {@code null},
     *         {@code paramNames} is not configured, or any name in {@code paramNames} is missing
     *         from {@code params}
     * @throws IllegalStateException if {@code model} is missing or empty
     */
    @Override
    public RescorerBuilder<?> handle(ImmutableMap<String, ?> fieldConfig, List<ScoreFunctionField> functions,
            Map<String, Object> params) {
        String model = MapUtils.getString(fieldConfig, "model");
        Preconditions.checkState(!Strings.isNullOrEmpty(model),
                "sltr rescore config requires a non-empty 'model'");
        Set<String> paramNames = MoreMaps.getObject(fieldConfig, "paramNames");
        if (params == null || paramNames == null || !params.keySet().containsAll(paramNames)) {
            // Required parameters are unavailable: signal "no rescoring" to the caller.
            return null;
        }
        QueryRescorerBuilder rescorer = new QueryRescorerBuilder(new SltrQueryBuilder(model, params));
        float queryWeight = MapUtils.getFloatValue(fieldConfig, "query_weight", DEFAULT_WEIGHT);
        if (isNonDefaultWeight(queryWeight)) {
            rescorer.setQueryWeight(queryWeight);
        }
        float rescoreQueryWeight = MapUtils.getFloatValue(fieldConfig, "rescore_query_weight", DEFAULT_WEIGHT);
        if (isNonDefaultWeight(rescoreQueryWeight)) {
            rescorer.setRescoreQueryWeight(rescoreQueryWeight);
        }
        return rescorer;
    }

    /**
     * Returns whether {@code weight} differs by value from {@link #DEFAULT_WEIGHT}.
     * A value comparison is required here: comparing {@code Float.floatToIntBits(weight)} with
     * the integer {@code 1} compares a bit pattern ({@code 0x3f800000} for {@code 1f}) and is
     * true for virtually every weight.
     */
    private static boolean isNonDefaultWeight(float weight) {
        return Float.compare(weight, DEFAULT_WEIGHT) != 0;
    }

    /**
     * Query builder that renders {@code {"sltr": {"model": <model>, "params": <params>}}}.
     *
     * <p>Only the XContent (JSON) rendering is implemented; transport serialization and Lucene
     * query construction throw {@link UnsupportedOperationException}.
     */
    public static class SltrQueryBuilder extends AbstractQueryBuilder<SltrQueryBuilder> {

        private final String model;
        // Stored by reference (not copied); later changes to the caller's map are visible here.
        private final Map<String, Object> params;

        /**
         * @param model  model name written under {@code "model"}
         * @param params parameters written under {@code "params"}; held by reference
         */
        public SltrQueryBuilder(String model, Map<String, Object> params) {
            this.model = model;
            this.params = params;
        }

        /** Not supported: this builder only provides an XContent rendering. */
        @Override
        protected void doWriteTo(StreamOutput out) throws IOException {
            throw new UnsupportedOperationException("sltr doWriteTo");
        }

        /**
         * Writes the {@code "sltr"} object with its {@code model} and {@code params} fields.
         * The surrounding object braces are emitted by {@link AbstractQueryBuilder#toXContent}.
         */
        @Override
        protected void doXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject(getWriteableName());
            builder.field("model", this.model);
            builder.field("params", this.params);
            builder.endObject();
        }

        /** Not supported: this builder only provides an XContent rendering. */
        @Override
        protected Query doToQuery(QueryShardContext context) throws IOException {
            throw new UnsupportedOperationException("sltr doToQuery");
        }

        @Override
        protected boolean doEquals(SltrQueryBuilder other) {
            return Objects.equals(model, other.model)
                    && Objects.equals(params, other.params);
        }

        @Override
        protected int doHashCode() {
            return Objects.hash(model, params);
        }

        /** @return {@code "sltr"}, the query name used as the JSON object key */
        @Override
        public String getWriteableName() {
            return "sltr";
        }

    }

}
